#include "ref.h"

#include <stdint.h>

/**
 * Reference single-precision matrix multiply: pDst = pSrcA * pSrcB.
 *
 * All three matrices are addressed row-major through their pData buffers.
 * Only pSrcA->numRows, pSrcA->numCols and pSrcB->numCols are read. The
 * function does not check pSrcB->numRows or the pDst dimensions, so the
 * caller must supply compatible shapes and a destination holding at least
 * numRows(A) * numCols(B) elements.
 *
 * Each output element is a float32_t dot product, accumulated in order of
 * increasing inner index.
 *
 * @param pSrcA  left operand  (numRows(A) x numCols(A))
 * @param pSrcB  right operand (numCols(A) x numCols(B))
 * @param pDst   output matrix (numRows(A) x numCols(B)), written in full
 * @return RISCV_MATH_SUCCESS unconditionally.
 */
riscv_status ref_mat_mult_f32(const riscv_matrix_instance_f32 *pSrcA,
                            const riscv_matrix_instance_f32 *pSrcB,
                            riscv_matrix_instance_f32 *pDst)
{
    const uint32_t rowsA = pSrcA->numRows;
    const uint32_t colsB = pSrcB->numCols;
    const uint32_t shared = pSrcA->numCols;
    const float32_t *a = pSrcA->pData;
    const float32_t *b = pSrcB->pData;
    float32_t *out = pDst->pData;
    uint32_t row = 0;

    while (row < rowsA) {
        const uint32_t aBase = row * shared;   /* start of row `row` in A */
        const uint32_t oBase = row * colsB;    /* start of row `row` in C */
        uint32_t col = 0;

        while (col < colsB) {
            float32_t acc = 0;
            uint32_t k;

            /* Walk row `row` of A and column `col` of B together. */
            for (k = 0; k < shared; ++k) {
                acc += a[aBase + k] * b[k * colsB + col];
            }

            out[oBase + col] = acc;
            ++col;
        }

        ++row;
    }

    return RISCV_MATH_SUCCESS;
}

/**
 * Reference Q31 matrix multiply: pDst = pSrcA * pSrcB.
 *
 * All three matrices are addressed row-major through their pData buffers.
 * Only pSrcA->numRows, pSrcA->numCols and pSrcB->numCols are read. The
 * function does not check pSrcB->numRows or the pDst dimensions, so the
 * caller must supply compatible shapes and a destination holding at least
 * numRows(A) * numCols(B) elements.
 *
 * Each 32x32-bit product is formed exactly in 64 bits. The products are
 * summed modulo 2^64, so the intermediate sum wraps instead of invoking
 * signed-overflow UB. The 64-bit result is shifted right by 31 and passed
 * to ref_sat_q31 (presumably a saturate-to-q31 helper; confirm in ref.h).
 * NOTE(review): the wrapping accumulator presumably mirrors the optimized
 * kernel's non-saturating 64-bit accumulator; confirm against its docs.
 *
 * @param pSrcA  left operand  (numRows(A) x numCols(A))
 * @param pSrcB  right operand (numCols(A) x numCols(B))
 * @param pDst   output matrix (numRows(A) x numCols(B)), written in full
 * @return RISCV_MATH_SUCCESS unconditionally.
 */
riscv_status ref_mat_mult_q31(const riscv_matrix_instance_q31 *pSrcA,
                            const riscv_matrix_instance_q31 *pSrcB,
                            riscv_matrix_instance_q31 *pDst)
{
    uint32_t r, c, i, outR, outC, innerSize;

    outR = pSrcA->numRows;
    outC = pSrcB->numCols;
    innerSize = pSrcA->numCols;

    for (r = 0; r < outR; r++) {
        for (c = 0; c < outC; c++) {
            /* Each product fits in q63_t (|a*b| <= 2^62). The running total
             * can exceed INT64_MAX after just two extreme terms, and signed
             * overflow is undefined behavior. Unsigned arithmetic wraps
             * modulo 2^64 by definition, so accumulate there instead. */
            uint64_t acc = 0;
            q63_t sum;

            for (i = 0; i < innerSize; i++) {
                q63_t prod = (q63_t)(pSrcA->pData[r * innerSize + i]) *
                             pSrcB->pData[i * outC + c];
                acc += (uint64_t)prod;
            }

            /* Reinterpret the wrapped total as two's complement without
             * relying on the implementation-defined unsigned->signed
             * conversion: values above INT64_MAX map to acc - 2^64. */
            if (acc <= (uint64_t)INT64_MAX) {
                sum = (q63_t)acc;
            } else {
                sum = -(q63_t)(UINT64_MAX - acc) - 1;
            }

            pDst->pData[r * outC + c] = ref_sat_q31(sum >> 31);
        }
    }

    return RISCV_MATH_SUCCESS;
}

/**
 * Reference Q15 matrix multiply: pDst = pSrcA * pSrcB.
 *
 * All three matrices are addressed row-major through their pData buffers.
 * Only pSrcA->numRows, pSrcA->numCols and pSrcB->numCols are read. The
 * function does not check pSrcB->numRows or the pDst dimensions, so the
 * caller must supply compatible shapes and a destination holding at least
 * numRows(A) * numCols(B) elements.
 *
 * Each 16x16-bit product is formed in q31_t and summed into a q63_t
 * accumulator. The sum is shifted right by 15 and passed to ref_sat_q15
 * (presumably a saturate-to-q15 helper; confirm in ref.h).
 *
 * @param pSrcA  left operand  (numRows(A) x numCols(A))
 * @param pSrcB  right operand (numCols(A) x numCols(B))
 * @param pDst   output matrix (numRows(A) x numCols(B)), written in full
 * @return RISCV_MATH_SUCCESS unconditionally.
 */
riscv_status ref_mat_mult_q15(const riscv_matrix_instance_q15 *pSrcA,
                            const riscv_matrix_instance_q15 *pSrcB,
                            riscv_matrix_instance_q15 *pDst)
{
    const uint32_t nRows = pSrcA->numRows;
    const uint32_t nCols = pSrcB->numCols;
    const uint32_t nInner = pSrcA->numCols;
    const q15_t *a = pSrcA->pData;
    const q15_t *b = pSrcB->pData;
    q15_t *out = pDst->pData;

    for (uint32_t row = 0; row < nRows; ++row) {
        const uint32_t aBase = row * nInner;

        for (uint32_t col = 0; col < nCols; ++col) {
            q63_t acc = 0;

            for (uint32_t k = 0; k < nInner; ++k) {
                /* 16x16 product is exact in 32 bits. */
                const q31_t prod = (q31_t)a[aBase + k] * b[k * nCols + col];
                acc += prod;
            }

            out[row * nCols + col] = ref_sat_q15(acc >> 15);
        }
    }

    return RISCV_MATH_SUCCESS;
}
